/*
 * SPDX-License-Identifier: Apache-2.0
 *
 * The OpenSearch Contributors require contributions made to
 * this file be licensed under the Apache-2.0 license or a
 * compatible open source license.
 */

package org.opensearch.plugin.kafka;

import org.opensearch.action.admin.cluster.node.info.NodeInfo;
import org.opensearch.action.admin.cluster.node.info.NodesInfoRequest;
import org.opensearch.action.admin.cluster.node.info.NodesInfoResponse;
import org.opensearch.action.admin.cluster.node.info.PluginsAndModules;
import org.opensearch.action.admin.cluster.snapshots.create.CreateSnapshotResponse;
import org.opensearch.action.admin.indices.stats.IndexStats;
import org.opensearch.action.admin.indices.stats.ShardStats;
import org.opensearch.action.admin.indices.streamingingestion.pause.PauseIngestionResponse;
import org.opensearch.action.admin.indices.streamingingestion.resume.ResumeIngestionResponse;
import org.opensearch.action.admin.indices.streamingingestion.state.GetIngestionStateResponse;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.cluster.ClusterState;
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.routing.allocation.command.AllocateReplicaAllocationCommand;
import org.opensearch.cluster.routing.allocation.command.MoveAllocationCommand;
import org.opensearch.common.settings.Settings;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.RangeQueryBuilder;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.indices.pollingingest.PollingIngestStats;
import org.opensearch.plugins.PluginInfo;
import org.opensearch.test.InternalTestCluster;
import org.opensearch.test.OpenSearchIntegTestCase;
import org.opensearch.transport.client.Requests;
import org.junit.Assert;

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertAcked;
import static org.hamcrest.Matchers.is;
import static org.awaitility.Awaitility.await;

/**
 * Integration tests for pull-based ingestion from Kafka.
 * <p>
 * Covers plugin installation, basic ingestion, pointer reset (by timestamp and by offset), closing an
 * ingesting index, per-message operation types ({@code index}, {@code delete}, {@code create}), multiple
 * writer threads, and the {@code ingestion_source.all_active} mode in which replicas ingest from the stream
 * too (pause/resume, replica relocation, node restart, replica promotion, snapshot/restore).
 */
@OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.TEST, numDataNodes = 0)
public class IngestFromKafkaIT extends KafkaIngestionBaseIT {

    /**
     * Mapping for indices created directly by this class: a text {@code name} field and an integer
     * {@code age} field. Kept as a single well-formed JSON document so every test uses the same mapping.
     */
    private static final String INGESTION_INDEX_MAPPING =
        "{\"properties\":{\"name\":{\"type\": \"text\"},\"age\":{\"type\": \"integer\"}}}";

    /**
     * test ingestion-kafka-plugin is installed
     */
    public void testPluginsAreInstalled() {
        NodesInfoRequest nodesInfoRequest = new NodesInfoRequest();
        nodesInfoRequest.addMetric(NodesInfoRequest.Metric.PLUGINS.metricName());
        NodesInfoResponse nodesInfoResponse = OpenSearchIntegTestCase.client().admin().cluster().nodesInfo(nodesInfoRequest).actionGet();
        List<PluginInfo> pluginInfos = nodesInfoResponse.getNodes()
            .stream()
            .flatMap(
                (Function<NodeInfo, Stream<PluginInfo>>) nodeInfo -> nodeInfo.getInfo(PluginsAndModules.class).getPluginInfos().stream()
            )
            .collect(Collectors.toList());
        Assert.assertTrue(
            pluginInfos.stream().anyMatch(pluginInfo -> pluginInfo.getName().equals("org.opensearch.plugin.kafka.KafkaPlugin"))
        );
    }

    /**
     * Produces two messages before the index exists, then checks that both are polled and processed
     * while only the one with {@code age >= 21} matches the range query.
     */
    public void testKafkaIngestion() {
        produceData("1", "name1", "24");
        produceData("2", "name2", "20");
        createIndexWithDefaultSettings(1, 0);

        RangeQueryBuilder query = new RangeQueryBuilder("age").gte(21);
        await().atMost(10, TimeUnit.SECONDS).untilAsserted(() -> {
            refresh(indexName);
            SearchResponse response = client().prepareSearch(indexName).setQuery(query).get();
            assertThat(response.getHits().getTotalHits().value(), is(1L));
            PollingIngestStats stats = client().admin().indices().prepareStats(indexName).get().getIndex(indexName).getShards()[0]
                .getPollingIngestStats();
            assertNotNull(stats);
            assertThat(stats.getMessageProcessorStats().totalProcessedCount(), is(2L));
            assertThat(stats.getConsumerStats().totalPolledCount(), is(2L));
        });
    }

    /**
     * Verifies that {@code reset_by_timestamp} skips messages produced before the configured timestamp.
     */
    public void testKafkaIngestion_RewindByTimeStamp() {
        produceData("1", "name1", "24", 1739459500000L, "index");
        produceData("2", "name2", "20", 1739459800000L, "index");

        // 1739459500000 is the timestamp of the first message and 1739459800000 that of the second;
        // by resetting to 1739459600000, only the second message will be ingested.
        createRewindIngestionIndex("test_rewind_by_timestamp", "reset_by_timestamp", "1739459600000");

        RangeQueryBuilder query = new RangeQueryBuilder("age").gte(0);
        await().atMost(10, TimeUnit.SECONDS).untilAsserted(() -> {
            refresh("test_rewind_by_timestamp");
            SearchResponse response = client().prepareSearch("test_rewind_by_timestamp").setQuery(query).get();
            assertThat(response.getHits().getTotalHits().value(), is(1L));
        });
    }

    /**
     * Verifies that {@code reset_by_offset} starts consumption at the configured offset: with two messages
     * in the topic and a start offset of 1, exactly one document is expected.
     */
    public void testKafkaIngestion_RewindByOffset() {
        produceData("1", "name1", "24");
        produceData("2", "name2", "20");
        createRewindIngestionIndex("test_rewind_by_offset", "reset_by_offset", "1");

        RangeQueryBuilder query = new RangeQueryBuilder("age").gte(0);
        await().atMost(1, TimeUnit.MINUTES).untilAsserted(() -> {
            refresh("test_rewind_by_offset");
            SearchResponse response = client().prepareSearch("test_rewind_by_offset").setQuery(query).get();
            assertThat(response.getHits().getTotalHits().value(), is(1L));
        });
    }

    /**
     * Verifies that an ingesting index can be closed and that the close request is acknowledged.
     */
    public void testCloseIndex() throws Exception {
        createIndexWithDefaultSettings(1, 0);
        ensureGreen(indexName);
        assertTrue(client().admin().indices().close(Requests.closeIndexRequest(indexName)).get().isAcknowledged());
    }

    /**
     * Exercises the {@code index} (insert and overwrite), {@code delete} and {@code create} message
     * operation types against a single-shard index.
     */
    public void testMessageOperationTypes() throws Exception {
        // Step 1: Produce message and wait for it to be searchable

        produceData("1", "name", "25", defaultMessageTimestamp, "index");
        createIndexWithDefaultSettings(1, 0);
        ensureGreen(indexName);
        waitForState(() -> {
            BoolQueryBuilder query = new BoolQueryBuilder().must(new TermQueryBuilder("_id", "1"));
            SearchResponse response = client().prepareSearch(indexName).setQuery(query).get();
            assertThat(response.getHits().getTotalHits().value(), is(1L));
            return 25 == (Integer) response.getHits().getHits()[0].getSourceAsMap().get("age");
        });

        // Step 2: Update age field from 25 to 30 and validate

        produceData("1", "name", "30", defaultMessageTimestamp, "index");
        waitForState(() -> {
            BoolQueryBuilder query = new BoolQueryBuilder().must(new TermQueryBuilder("_id", "1"));
            SearchResponse response = client().prepareSearch(indexName).setQuery(query).get();
            assertThat(response.getHits().getTotalHits().value(), is(1L));
            return 30 == (Integer) response.getHits().getHits()[0].getSourceAsMap().get("age");
        });

        // Step 3: Delete the document and validate
        produceData("1", "name", "30", defaultMessageTimestamp, "delete");
        waitForState(() -> {
            BoolQueryBuilder query = new BoolQueryBuilder().must(new TermQueryBuilder("_id", "1"));
            SearchResponse response = client().prepareSearch(indexName).setQuery(query).get();
            return response.getHits().getTotalHits().value() == 0;
        });

        // Step 4: Validate create operation
        produceData("2", "name", "30", defaultMessageTimestamp, "create");
        waitForState(() -> {
            BoolQueryBuilder query = new BoolQueryBuilder().must(new TermQueryBuilder("_id", "2"));
            SearchResponse response = client().prepareSearch(indexName).setQuery(query).get();
            assertThat(response.getHits().getTotalHits().value(), is(1L));
            return 30 == (Integer) response.getHits().getHits()[0].getSourceAsMap().get("age");
        });
    }

    /**
     * Ingests a message that carries no {@code _id}, reads back the id assigned to the resulting document,
     * and then updates that document by producing a message with the retrieved id.
     */
    public void testUpdateWithoutIDField() throws Exception {
        // Step 1: Produce message without ID
        String payload = "{\"_op_type\":\"index\",\"_source\":{\"name\":\"name\", \"age\": 25}}";
        produceData(payload);

        createIndexWithDefaultSettings(1, 0);
        ensureGreen(indexName);

        waitForState(() -> {
            BoolQueryBuilder query = new BoolQueryBuilder().must(new TermQueryBuilder("age", "25"));
            SearchResponse response = client().prepareSearch(indexName).setQuery(query).get();
            assertThat(response.getHits().getTotalHits().value(), is(1L));
            return 25 == (Integer) response.getHits().getHits()[0].getSourceAsMap().get("age");
        });

        SearchResponse searchableDocsResponse = client().prepareSearch(indexName).setSize(10).setPreference("_only_local").get();
        assertThat(searchableDocsResponse.getHits().getTotalHits().value(), is(1L));
        assertEquals(25, searchableDocsResponse.getHits().getHits()[0].getSourceAsMap().get("age"));
        String id = searchableDocsResponse.getHits().getHits()[0].getId();

        // Step 2: Produce an update message using retrieved ID and validate

        produceData(id, "name", "30", defaultMessageTimestamp, "index");
        waitForState(() -> {
            BoolQueryBuilder query = new BoolQueryBuilder().must(new TermQueryBuilder("_id", id));
            SearchResponse response = client().prepareSearch(indexName).setQuery(query).get();
            assertThat(response.getHits().getTotalHits().value(), is(1L));
            return 30 == (Integer) response.getHits().getHits()[0].getSourceAsMap().get("age");
        });
    }

    /**
     * Ingests 1000 documents and then 1000 updates into an index configured with 5 writer threads, checking
     * that all inserts and all updates become searchable.
     */
    public void testMultiThreadedWrites() throws Exception {
        // create index with 5 writer threads
        createIndexWithDefaultSettings(indexName, 1, 0, 5);
        ensureGreen(indexName);

        // Step 1: Produce messages
        for (int i = 0; i < 1000; i++) {
            produceData(Integer.toString(i), "name" + i, "25");
        }

        waitForState(() -> {
            SearchResponse searchableDocsResponse = client().prepareSearch(indexName).setSize(2000).setPreference("_only_local").get();
            return searchableDocsResponse.getHits().getTotalHits().value() == 1000;
        });

        // Step 2: Produce an update message and validate
        for (int i = 0; i < 1000; i++) {
            produceData(Integer.toString(i), "name" + i, "30");
        }

        waitForState(() -> {
            RangeQueryBuilder query = new RangeQueryBuilder("age").gte(28);
            SearchResponse response = client().prepareSearch(indexName).setQuery(query).get();
            return response.getHits().getTotalHits().value() == 1000;
        });
    }

    /**
     * Exercises all-active ingestion: replica catch-up, pause/resume across primary and replica, replica
     * relocation to a new node, and a replica node restart, finishing with checks that neither copy reports
     * dropped or failed messages.
     */
    public void testAllActiveIngestion() throws Exception {
        // Create pull-based index in default replication mode (docrep) and publish some messages

        internalCluster().startClusterManagerOnlyNode();
        final String nodeA = internalCluster().startDataOnlyNode();
        for (int i = 0; i < 10; i++) {
            produceData(Integer.toString(i), "name" + i, "30");
        }

        createAllActiveIngestionIndex();

        ensureYellowAndNoInitializingShards(indexName);
        waitForSearchableDocs(10, List.of(nodeA));
        flush(indexName);

        // add a second node and verify the replica ingests the data
        final String nodeB = internalCluster().startDataOnlyNode();
        ensureGreen(indexName);
        assertEquals(nodeA, primaryNodeName(indexName));
        assertEquals(nodeB, replicaNodeName(indexName));
        waitForSearchableDocs(10, List.of(nodeB));

        // verify pause and resume functionality on replica

        // pause ingestion
        PauseIngestionResponse pauseResponse = pauseIngestion(indexName);
        assertTrue(pauseResponse.isAcknowledged());
        assertTrue(pauseResponse.isShardsAcknowledged());
        waitForState(() -> {
            GetIngestionStateResponse ingestionState = getIngestionState(indexName);
            return ingestionState.getShardStates().length == 2
                && ingestionState.getFailedShards() == 0
                && Arrays.stream(ingestionState.getShardStates())
                    .allMatch(state -> state.isPollerPaused() && state.getPollerState().equalsIgnoreCase("paused"));
        });

        for (int i = 10; i < 20; i++) {
            produceData(Integer.toString(i), "name" + i, "30");
        }

        // replica must not ingest when paused; give it a moment to (incorrectly) pick up the new messages
        Thread.sleep(1000);
        assertEquals(10, getSearchableDocCount(nodeB));

        // resume ingestion
        ResumeIngestionResponse resumeResponse = resumeIngestion(indexName);
        assertTrue(resumeResponse.isAcknowledged());
        assertTrue(resumeResponse.isShardsAcknowledged());
        waitForState(() -> {
            GetIngestionStateResponse ingestionState = getIngestionState(indexName);
            return ingestionState.getShardStates().length == 2
                && Arrays.stream(ingestionState.getShardStates())
                    .allMatch(
                        state -> state.isPollerPaused() == false
                            && (state.getPollerState().equalsIgnoreCase("polling") || state.getPollerState().equalsIgnoreCase("processing"))
                    );
        });

        // verify replica ingests data after resuming ingestion
        waitForSearchableDocs(20, List.of(nodeA, nodeB));

        // produce 10 more messages
        for (int i = 20; i < 30; i++) {
            produceData(Integer.toString(i), "name" + i, "30");
        }

        // Add new node and wait for new node to join cluster
        final String nodeC = internalCluster().startDataOnlyNode();
        assertBusy(() -> {
            assertEquals(
                "Should have 4 nodes total (1 cluster manager + 3 data)",
                4,
                internalCluster().clusterService().state().nodes().getSize()
            );
        }, 30, TimeUnit.SECONDS);

        // move replica from nodeB to nodeC
        ensureGreen(indexName);
        client().admin().cluster().prepareReroute().add(new MoveAllocationCommand(indexName, 0, nodeB, nodeC)).get();
        ensureGreen(indexName);

        // confirm replica ingests messages after moving to new node
        waitForSearchableDocs(30, List.of(nodeA, nodeC));

        for (int i = 30; i < 40; i++) {
            produceData(Integer.toString(i), "name" + i, "30");
        }

        // restart replica node and verify ingestion
        internalCluster().restartNode(nodeC);
        ensureGreen(indexName);
        waitForSearchableDocs(40, List.of(nodeA, nodeC));

        // Verify both primary and replica do not have failed messages
        Map<String, PollingIngestStats> shardTypeToStats = getPollingIngestStatsForPrimaryAndReplica(indexName);
        assertNotNull(shardTypeToStats.get("primary"));
        assertNotNull(shardTypeToStats.get("replica"));
        assertThat(shardTypeToStats.get("primary").getConsumerStats().totalPollerMessageDroppedCount(), is(0L));
        assertThat(shardTypeToStats.get("primary").getConsumerStats().totalPollerMessageFailureCount(), is(0L));
        // the replica (restarted above) must likewise report no dropped or failed messages
        assertThat(shardTypeToStats.get("replica").getConsumerStats().totalPollerMessageDroppedCount(), is(0L));
        assertThat(shardTypeToStats.get("replica").getConsumerStats().totalPollerMessageFailureCount(), is(0L));

        GetIngestionStateResponse ingestionState = getIngestionState(indexName);
        assertEquals(2, ingestionState.getShardStates().length);
        assertEquals(0, ingestionState.getFailedShards());
    }

    /**
     * Stops the primary's node in an all-active index, checks that the replica is promoted and keeps
     * ingesting, and then allocates a fresh replica on a new node that catches up.
     */
    public void testReplicaPromotionOnAllActiveIngestion() throws Exception {
        // Create pull-based index in default replication mode (docrep) and publish some messages
        internalCluster().startClusterManagerOnlyNode();
        final String nodeA = internalCluster().startDataOnlyNode();
        for (int i = 0; i < 10; i++) {
            produceData(Integer.toString(i), "name" + i, "30");
        }

        createAllActiveIngestionIndex();

        ensureYellowAndNoInitializingShards(indexName);
        waitForSearchableDocs(10, List.of(nodeA));

        // add second node
        final String nodeB = internalCluster().startDataOnlyNode();
        ensureGreen(indexName);
        assertEquals(nodeA, primaryNodeName(indexName));
        assertEquals(nodeB, replicaNodeName(indexName));
        waitForSearchableDocs(10, List.of(nodeB));

        // Validate replica promotion
        internalCluster().stopRandomNode(InternalTestCluster.nameFilter(nodeA));
        ensureYellowAndNoInitializingShards(indexName);
        assertEquals(nodeB, primaryNodeName(indexName));
        for (int i = 10; i < 20; i++) {
            produceData(Integer.toString(i), "name" + i, "30");
        }

        waitForSearchableDocs(20, List.of(nodeB));

        // add third node and allocate the replica once the node joins the cluster
        final String nodeC = internalCluster().startDataOnlyNode();
        assertBusy(() -> { assertEquals(3, internalCluster().clusterService().state().nodes().getSize()); }, 30, TimeUnit.SECONDS);
        client().admin().cluster().prepareReroute().add(new AllocateReplicaAllocationCommand(indexName, 0, nodeC)).get();
        ensureGreen(indexName);
        waitForSearchableDocs(20, List.of(nodeC));

    }

    /**
     * Snapshots an all-active index, deletes it, produces more messages, restores it, and checks that both
     * copies end up with all documents while polling only the 20 messages produced after the snapshot.
     */
    public void testSnapshotRestoreOnAllActiveIngestion() throws Exception {
        // Create pull-based index in default replication mode (docrep) and publish some messages
        internalCluster().startClusterManagerOnlyNode();
        final String nodeA = internalCluster().startDataOnlyNode();
        final String nodeB = internalCluster().startDataOnlyNode();
        for (int i = 0; i < 20; i++) {
            produceData(Integer.toString(i), "name" + i, "30");
        }

        createAllActiveIngestionIndex();
        ensureGreen(indexName);
        waitForSearchableDocs(20, List.of(nodeA, nodeB));

        // Register snapshot repository
        String snapshotRepositoryName = "test-snapshot-repo";
        String snapshotName = "snapshot-1";
        assertAcked(
            client().admin()
                .cluster()
                .preparePutRepository(snapshotRepositoryName)
                .setType("fs")
                .setSettings(Settings.builder().put("location", randomRepoPath()).put("compress", false))
        );

        // Take snapshot
        flush(indexName);
        CreateSnapshotResponse snapshotResponse = client().admin()
            .cluster()
            .prepareCreateSnapshot(snapshotRepositoryName, snapshotName)
            .setWaitForCompletion(true)
            .setIndices(indexName)
            .get();
        assertTrue(snapshotResponse.getSnapshotInfo().successfulShards() > 0);

        // Delete Index
        assertAcked(client().admin().indices().prepareDelete(indexName));
        waitForState(() -> {
            ClusterState state = client().admin().cluster().prepareState().setIndices(indexName).get().getState();
            return state.getRoutingTable().hasIndex(indexName) == false && state.getMetadata().hasIndex(indexName) == false;
        });

        for (int i = 20; i < 40; i++) {
            produceData(Integer.toString(i), "name" + i, "30");
        }

        // Restore Index from Snapshot
        client().admin()
            .cluster()
            .prepareRestoreSnapshot(snapshotRepositoryName, snapshotName)
            .setWaitForCompletion(true)
            .setIndices(indexName)
            .get();
        ensureGreen(indexName);

        refresh(indexName);
        waitForSearchableDocs(40, List.of(nodeA, nodeB));

        // Verify both primary and replica have polled only remaining 20 messages
        Map<String, PollingIngestStats> shardTypeToStats = getPollingIngestStatsForPrimaryAndReplica(indexName);
        assertNotNull(shardTypeToStats.get("primary"));
        assertNotNull(shardTypeToStats.get("replica"));
        assertThat(shardTypeToStats.get("primary").getConsumerStats().totalPolledCount(), is(20L));
        assertThat(shardTypeToStats.get("primary").getConsumerStats().totalPollerMessageDroppedCount(), is(0L));
        assertThat(shardTypeToStats.get("primary").getConsumerStats().totalPollerMessageFailureCount(), is(0L));
        assertThat(shardTypeToStats.get("replica").getConsumerStats().totalPolledCount(), is(20L));
        assertThat(shardTypeToStats.get("replica").getConsumerStats().totalPollerMessageDroppedCount(), is(0L));
        assertThat(shardTypeToStats.get("replica").getConsumerStats().totalPollerMessageFailureCount(), is(0L));
    }

    /**
     * Creates a single-shard, zero-replica index reading the {@code test} topic with the given pointer reset
     * strategy.
     *
     * @param index      name of the index to create
     * @param resetMode  value for {@code ingestion_source.pointer.init.reset}, e.g. {@code reset_by_offset}
     * @param resetValue value for {@code ingestion_source.pointer.init.reset.value} (an offset or timestamp)
     */
    private void createRewindIngestionIndex(String index, String resetMode, String resetValue) {
        createIndex(
            index,
            Settings.builder()
                .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1)
                .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0)
                .put("ingestion_source.type", "kafka")
                .put("ingestion_source.pointer.init.reset", resetMode)
                .put("ingestion_source.pointer.init.reset.value", resetValue)
                .put("ingestion_source.param.topic", "test")
                .put("ingestion_source.param.bootstrap_servers", kafka.getBootstrapServers())
                .put("ingestion_source.param.auto.offset.reset", "latest")
                .build(),
            INGESTION_INDEX_MAPPING
        );
    }

    /**
     * Creates {@code indexName} with one shard and one replica, reading {@code topicName} from the earliest
     * offset with {@code ingestion_source.all_active} enabled.
     */
    private void createAllActiveIngestionIndex() {
        createIndex(
            indexName,
            Settings.builder()
                .put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, 1)
                .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 1)
                .put("ingestion_source.type", "kafka")
                .put("ingestion_source.param.topic", topicName)
                .put("ingestion_source.param.bootstrap_servers", kafka.getBootstrapServers())
                .put("ingestion_source.pointer.init.reset", "earliest")
                .put("ingestion_source.all_active", true)
                .build(),
            INGESTION_INDEX_MAPPING
        );
    }

    /**
     * Returns the polling-ingest stats of a one-primary/one-replica index keyed by {@code "primary"} and
     * {@code "replica"}.
     *
     * @param indexName index to read stats for; it must report exactly two shard copies
     * @return map from shard role to that copy's {@link PollingIngestStats} (values may be null)
     */
    private Map<String, PollingIngestStats> getPollingIngestStatsForPrimaryAndReplica(String indexName) {
        IndexStats indexStats = client().admin().indices().prepareStats(indexName).get().getIndex(indexName);
        ShardStats[] shards = indexStats.getShards();
        assertEquals(2, shards.length);
        Map<String, PollingIngestStats> shardTypeToStats = new HashMap<>();
        for (ShardStats shardStats : shards) {
            String role = shardStats.getShardRouting().primary() ? "primary" : "replica";
            shardTypeToStats.put(role, shardStats.getPollingIngestStats());
        }
        return shardTypeToStats;
    }
}
